{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b44aa468",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import json\n",
    "import time\n",
    "import random\n",
    "from openai import OpenAI\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "from items import Item \n",
    "from testing import Tester  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a8348f3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "client = OpenAI(api_key=os.getenv(\"OPENAI_API_KEY\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f4cf994b",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE_MODEL = \"gpt-4o-mini-2024-07-18\"\n",
    "EPOCHS = 5\n",
    "BATCH_SIZE = 8\n",
    "LR_MULT = 0.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f0ac12",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"train.pkl\", \"rb\") as f:\n",
    "    train_subset = pickle.load(f)\n",
    "\n",
    "with open(\"test.pkl\", \"rb\") as f:\n",
    "    val_subset = pickle.load(f) \n",
    "\n",
    "print(f\"Loaded {len(train_subset)} training and {len(val_subset)} validation items.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aad35ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.shuffle(train_subset)\n",
    "random.shuffle(val_subset)\n",
    "\n",
    "TRAIN_LIMIT = 5000 \n",
    "VAL_LIMIT = 1000    \n",
    "\n",
    "train_subset = train_subset[:TRAIN_LIMIT]\n",
    "val_subset = val_subset[:VAL_LIMIT]\n",
    "\n",
    "print(f\"Using {len(train_subset)} training and {len(val_subset)} validation samples.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2c85d07e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prompt(item):\n",
    "    return f\"\"\"\n",
    "### CONTEXT\n",
    "You are a price estimation assistant for e-commerce listings.\n",
    "Each product is described by its title, category, key features, and details.\n",
    "\n",
    "### TASK\n",
    "Estimate the most likely retail price in USD.\n",
    "Think step-by-step about product type, quality, and included components \n",
    "before stating the final answer as \"Predicted Price: $<amount>\".\n",
    "\n",
    "### EXAMPLES\n",
    "- Wireless earbuds with active noise cancellation -> Predicted Price: $89\n",
    "- Stainless steel kitchen knife set (6-piece) -> Predicted Price: $45\n",
    "- Laptop stand aluminum adjustable -> Predicted Price: $32\n",
    "\n",
    "### PRODUCT TITLE\n",
    "{item.title}\n",
    "\n",
    "### CATEGORY\n",
    "{item.category}\n",
    "\n",
    "### DETAILS\n",
    "{item.details}\n",
    "\n",
    "### YOUR REASONING\n",
    "(Think about product quality, features, and typical market range.)\n",
    "\n",
    "### FINAL ANSWER\n",
    "Predicted Price: $\n",
    "\"\"\"\n",
    "\n",
    "def build_completion(item):\n",
    "    return f\"Predicted Price: ${round(item.price)}.00\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83eca819",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_jsonl(data, filename):\n",
    "    with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
    "        for item in data:\n",
    "            if getattr(item, \"include\", True):\n",
    "                prompt = build_prompt(item)\n",
    "                completion = build_completion(item)\n",
    "                json_obj = {\n",
    "                    \"messages\": [\n",
    "                        {\"role\": \"user\", \"content\": prompt},\n",
    "                        {\"role\": \"assistant\", \"content\": completion}\n",
    "                    ]\n",
    "                }\n",
    "                f.write(json.dumps(json_obj) + \"\\n\")\n",
    "    print(f\"Wrote {len(data)} samples to {filename}\")\n",
    "\n",
    "TRAIN_JSONL = \"train_prepared.jsonl\"\n",
    "VAL_JSONL = \"val_prepared.jsonl\"\n",
    "\n",
    "write_jsonl(train_subset, TRAIN_JSONL)\n",
    "write_jsonl(val_subset, VAL_JSONL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d25bfc85",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file = client.files.create(file=open(TRAIN_JSONL, \"rb\"), purpose=\"fine-tune\")\n",
    "val_file = client.files.create(file=open(VAL_JSONL, \"rb\"), purpose=\"fine-tune\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4e092a",
   "metadata": {},
   "outputs": [],
   "source": [
    "job = client.fine_tuning.jobs.create(\n",
    "    training_file=train_file.id,\n",
    "    validation_file=val_file.id,\n",
    "    model=BASE_MODEL,\n",
    "    hyperparameters={\n",
    "        \"n_epochs\": EPOCHS,\n",
    "        \"batch_size\": BATCH_SIZE,\n",
    "        \"learning_rate_multiplier\": LR_MULT\n",
    "    }\n",
    ")\n",
    "\n",
    "print(\"Job started:\", job.id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d58b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stream_finetune_events(job_id, poll_interval=30):\n",
    "    print(f\"Tracking fine-tuning job: {job_id}\\n\")\n",
    "    seen = set()\n",
    "    loss_data = []\n",
    "    \n",
    "    while True:\n",
    "        job = client.fine_tuning.jobs.retrieve(job_id)\n",
    "        events = client.fine_tuning.jobs.list_events(job_id)\n",
    "        \n",
    "        for e in events.data[::-1]:\n",
    "            if e.id not in seen:\n",
    "                seen.add(e.id)\n",
    "                ts = datetime.fromtimestamp(e.created_at)\n",
    "                msg = e.message\n",
    "                print(f\"[{ts:%Y-%m-%d %H:%M:%S}] {msg}\")\n",
    "                \n",
    "                if \"training_loss\" in msg:\n",
    "                    try:\n",
    "                        step = int(msg.split(\"Step \")[1].split(\"/\")[0])\n",
    "                        train_loss = float(msg.split(\"training_loss: \")[1].split(\",\")[0])\n",
    "                        val_loss = None\n",
    "                        if \"val_loss\" in msg:\n",
    "                            val_loss = float(msg.split(\"val_loss: \")[1].split(\",\")[0])\n",
    "                        loss_data.append((step, train_loss, val_loss))\n",
    "                    except Exception:\n",
    "                        pass\n",
    "        \n",
    "        if job.status == \"succeeded\":\n",
    "            print(\"\\nFine-tuning complete!\")\n",
    "            print(\"Fine-tuned model ID:\", job.fine_tuned_model)\n",
    "            \n",
    "            if loss_data:\n",
    "                steps = [d[0] for d in loss_data]\n",
    "                train_losses = [d[1] for d in loss_data]\n",
    "                val_losses = [d[2] for d in loss_data if d[2] is not None]\n",
    "\n",
    "                plt.figure(figsize=(8, 5))\n",
    "                plt.plot(steps, train_losses, marker=\"o\", color=\"teal\", label=\"Training Loss\")\n",
    "                if val_losses:\n",
    "                    plt.plot(steps[:len(val_losses)], val_losses, marker=\"o\", color=\"orange\", label=\"Validation Loss\")\n",
    "                plt.xlabel(\"Step\")\n",
    "                plt.ylabel(\"Loss\")\n",
    "                plt.title(f\"Fine-Tuning Progress — {job_id}\")\n",
    "                plt.legend()\n",
    "                plt.grid(alpha=0.3)\n",
    "                plt.show()\n",
    "            else:\n",
    "                print(\"No loss data found. Fine-tuning may have completed too quickly to log metrics.\")\n",
    "\n",
    "            return job.fine_tuned_model\n",
    "\n",
    "        elif job.status in [\"failed\", \"cancelled\"]:\n",
    "            print(f\"\\nFine-tuning {job.status}.\")\n",
    "            if job.error:\n",
    "                print(\"Error:\", job.error)\n",
    "            return None\n",
    "\n",
    "        time.sleep(poll_interval)\n",
    "\n",
    "MODEL_ID = stream_finetune_events(job.id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3ed9c0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model(model_id, test_items, max_samples=100):\n",
    "    y_true, y_pred = [], []\n",
    "    for i, item in enumerate(test_items[:max_samples]):\n",
    "        prompt = build_prompt(item)\n",
    "        response = client.chat.completions.create(\n",
    "            model=model_id,\n",
    "            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "            temperature=0\n",
    "        )\n",
    "        output = response.choices[0].message.content\n",
    "        try:\n",
    "            pred_price = float(output.split(\"$\")[1].split()[0])\n",
    "        except:\n",
    "            continue\n",
    "        y_true.append(item.price)\n",
    "        y_pred.append(pred_price)\n",
    "        print(f\"{i+1}. {item.title[:50]} | Actual: ${item.price} | Pred: ${pred_price}\")\n",
    "    return y_true, y_pred\n",
    "\n",
    "y_true, y_pred = test_model(MODEL_ID, val_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "347c22cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "errors = np.abs(np.array(y_true) - np.array(y_pred))\n",
    "colors = [\"green\" if e < 10 else \"orange\" if e < 25 else \"red\" for e in errors]\n",
    "\n",
    "plt.figure(figsize=(10,6))\n",
    "plt.scatter(range(len(y_true)), y_true, color='blue', label='Actual', alpha=0.6)\n",
    "plt.scatter(range(len(y_pred)), y_pred, color=colors, label='Predicted', alpha=0.8)\n",
    "plt.title(\"Fine-tuned Price Prediction Performance (Color-Coded by Error)\")\n",
    "plt.xlabel(\"Sample Index\")\n",
    "plt.ylabel(\"Price ($)\")\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "avg_error = np.mean(errors)\n",
    "print(f\"\\nAverage error: ${avg_error:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-engineering",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
